{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 中文地址解析"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 模型配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bert 102 100 119\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import random\n",
    "import sklearn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from   torch import nn\n",
    "from   tqdm import tqdm\n",
    "\n",
    "from transformers import AdamW\n",
    "from transformers import get_linear_schedule_with_warmup\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "max_train_epochs            = 5\n",
    "warmup_proportion           = 0.05\n",
    "gradient_accumulation_steps = 1\n",
    "train_batch_size = 16\n",
    "valid_batch_size = train_batch_size\n",
    "test_batch_size  = train_batch_size\n",
    "data_workers     = 2\n",
    "# save_checkpoint = False\n",
    "\n",
    "learning_rate  = 2e-5\n",
    "weight_decay   = 0.01\n",
    "max_grad_norm  = 1.0\n",
    "\n",
    "\n",
    "base_path    = './'\n",
    "model_select = \"bert\"\n",
    "\n",
    "from transformers import BertConfig, BertTokenizer, BertModel, BertForTokenClassification\n",
    "cls_token  = '[CLS]'\n",
    "eos_token  = '[SEP]'\n",
    "unk_token  = '[UNK]'\n",
    "pad_token  = '[PAD]'\n",
    "mask_token = '[MASK]'\n",
    "\n",
    "tokenizer  = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "config     = BertConfig.from_pretrained('bert-base-chinese')\n",
    "TheModel   = BertModel\n",
    "ModelForTokenClassification = BertForTokenClassification\n",
    "\n",
    "eos_id    = tokenizer.convert_tokens_to_ids([eos_token])[0]\n",
    "unk_id    = tokenizer.convert_tokens_to_ids([unk_token])[0]\n",
    "period_id = tokenizer.convert_tokens_to_ids(['.'])[0]\n",
    "print(model_select, eos_id, unk_id, period_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46\n"
     ]
    }
   ],
   "source": [
    "labels = ['B-assist', 'I-assist', 'B-cellno', 'I-cellno', 'B-city', 'I-city', 'B-community', 'I-community', \n",
    "        'B-country', 'I-country', 'B-devZone', 'I-devZone', 'B-district', 'I-district', 'B-floorno', \n",
    "        'I-floorno', 'B-houseno', 'I-houseno', 'B-otherinfo', 'I-otherinfo', 'B-person', 'I-person', \n",
    "        'B-poi', 'I-poi', 'B-prov', 'I-prov', 'B-redundant', 'I-redundant', 'B-road', 'I-road', \n",
    "        'B-roadno', 'I-roadno', 'B-roomno', 'I-roomno', 'B-subRoad', 'I-subRoad', 'B-subRoadno', \n",
    "        'I-subRoadno', 'B-subpoi', 'I-subpoi', 'B-subroad', 'I-subroad', 'B-subroadno', 'I-subroadno', \n",
    "        'B-town', 'I-town']\n",
    "label2id = {}\n",
    "for i, l in enumerate(labels):\n",
    "    label2id[l] = i\n",
    "num_labels = len(labels)\n",
    "print(num_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 载入数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8957 2985 2985\n",
      "max_token_len 76\n"
     ]
    }
   ],
   "source": [
    "def get_data_list(f):\n",
    "    data_list = []\n",
    "    origin_token, token, label = [], [], []\n",
    "    for l in f:\n",
    "        l = l.strip().split()\n",
    "        if not l:\n",
    "            data_list.append([token, label, origin_token])\n",
    "            origin_token, token, label = [], [], []\n",
    "            continue\n",
    "        for i, tok in enumerate(l[0]):\n",
    "            token.append(tok)\n",
    "            label.append(label2id[l[1]])\n",
    "        origin_token.append(l[0])\n",
    "    assert len(token) == 0\n",
    "    return data_list\n",
    "\n",
    "f_train = open(base_path + 'train.txt')\n",
    "f_test  = open(base_path + 'test.txt')\n",
    "f_dev   = open(base_path + 'dev.txt')\n",
    "\n",
    "train_list = get_data_list(f_train)\n",
    "test_list  = get_data_list(f_test)\n",
    "dev_list   = get_data_list(f_dev)\n",
    "\n",
    "print(len(train_list), len(test_list), len(dev_list))\n",
    "max_token_len = 0\n",
    "for ls in [train_list, test_list, dev_list]:\n",
    "    for l in ls:\n",
    "        max_token_len = max(max_token_len, len(l[0]))\n",
    "print('max_token_len', max_token_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 数据加载器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataSet(torch.utils.data.Dataset):\n",
    "    def __init__(self, examples):\n",
    "        self.examples = examples\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        example  = self.examples[index]\n",
    "        sentence = example[0]\n",
    "        label    = example[1]\n",
    "        \n",
    "        pad_len   = max_token_len - len(sentence)\n",
    "        total_len = len(sentence)+2\n",
    "        \n",
    "        input_token    = [cls_token] + sentence + [eos_token] + [pad_token] * pad_len\n",
    "        input_ids      = tokenizer.convert_tokens_to_ids(input_token)\n",
    "        attention_mask = [1] + [1] * len(sentence) + [1] + [0] * pad_len\n",
    "\n",
    "        label = [-100] + label + [-100] + [-100] * pad_len\n",
    "        # assert max_token_len + 2 == len(input_ids) == len(attention_mask) == len(input_token)\n",
    "        \n",
    "        return input_ids, attention_mask, total_len, label, index\n",
    "    \n",
    "\n",
    "def collate_fn(batch):\n",
    "    total_lens     = [b[2] for b in batch]\n",
    "    total_len      = max(total_lens)\n",
    "    input_ids      = torch.LongTensor([b[0] for b in batch])\n",
    "    attention_mask = torch.LongTensor([b[1] for b in batch])\n",
    "    label          = torch.LongTensor([b[3] for b in batch])\n",
    "    input_ids      = input_ids[:,:total_len]\n",
    "    attention_mask = attention_mask[:,:total_len]\n",
    "    label          = label[:,:total_len]\n",
    "\n",
    "    indexs = [b[4] for b in batch]\n",
    "    return input_ids, attention_mask, label, indexs\n",
    "\n",
    "train_dataset = DataSet(train_list)\n",
    "train_loader  = torch.utils.data.DataLoader(\n",
    "                    train_dataset,\n",
    "                    batch_size=train_batch_size,\n",
    "                    shuffle = True,\n",
    "                    num_workers=data_workers,\n",
    "                    collate_fn=collate_fn,\n",
    "                )\n",
    "\n",
    "test_dataset = DataSet(test_list)\n",
    "test_loader  = torch.utils.data.DataLoader(\n",
    "                    train_dataset,\n",
    "                    batch_size=train_batch_size,\n",
    "                    shuffle = False,\n",
    "                    num_workers=data_workers,\n",
    "                    collate_fn=collate_fn,\n",
    "                )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 模型定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertForSeqTagging(ModelForTokenClassification):\n",
    "    def __init__(self):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = num_labels\n",
    "        self.bert       = TheModel.from_pretrained('bert-base-chinese')\n",
    "        self.dropout    = torch.nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)\n",
    "        self.init_weights()\n",
    "        \n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        outputs         = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
    "        sequence_output = outputs[0]\n",
    "        \n",
    "        sequence_output = self.dropout(sequence_output)\n",
    "        logits          = self.classifier(sequence_output)\n",
    "        active_loss     = attention_mask.reshape(-1) == 1\n",
    "        active_logits   = logits.view(-1, self.num_labels)[active_loss]\n",
    "\n",
    "        if labels is not None:\n",
    "            loss_fct = torch.nn.CrossEntropyLoss()\n",
    "            active_labels = labels.reshape(-1)[active_loss]\n",
    "            loss = loss_fct(active_logits, active_labels)\n",
    "            return loss\n",
    "        else:\n",
    "            return active_logits\n",
    "        \n",
    "model = BertForSeqTagging()\n",
    "model.to(device)\n",
    "t_total = len(train_loader) // gradient_accumulation_steps * max_train_epochs + 1\n",
    "\n",
    "num_warmup_steps = int(warmup_proportion * t_total)\n",
    "\n",
    "no_decay = ['bias', 'LayerNorm.weight'] # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "param_optimizer = list(model.named_parameters())\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params':[p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay': weight_decay},\n",
    "    {'params':[p for n, p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay': 0.0}\n",
    "]\n",
    "optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=False)\n",
    "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)\n",
    "\n",
    "\n",
    "\n",
    "def eval():\n",
    "    result = []\n",
    "    for step, batch in enumerate(tqdm(test_loader)):\n",
    "        input_ids, attention_mask, label = (b.to(device) for b in batch[:-1])\n",
    "        with torch.no_grad():\n",
    "            logits = model(input_ids, attention_mask)\n",
    "            logits = F.softmax(logits, dim=-1)\n",
    "        logits = logits.data.cpu()\n",
    "        logit_list = []\n",
    "        sum_len = 0\n",
    "        for m in attention_mask:\n",
    "            l = m.sum().cpu().item()\n",
    "            logit_list.append(logits[sum_len:sum_len+l])\n",
    "            sum_len += l\n",
    "        assert sum_len == len(logits)\n",
    "        for i, l in enumerate(logit_list):\n",
    "            rr = torch.argmax(l, dim=1)\n",
    "            for j, w in enumerate(test_list[batch[-1][i]][0]):\n",
    "                result.append([w, labels[label[i][j+1].cpu().item()],labels[rr[j+1]]])\n",
    "            result.append([])\n",
    "    print(result[:20])\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/560 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|▏         | 7/560 [00:31<22:34,  2.45s/it]  "
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "for epoch in range(max_train_epochs):\n",
    "    epoch_loss = None\n",
    "    epoch_step = 0\n",
    "    start_time = time.time()\n",
    "    model.train()\n",
    "    for step, batch in enumerate(tqdm(train_loader)):\n",
    "        input_ids, attention_mask, label = (b.to(device) for b in batch[:-1])\n",
    "        loss = model(input_ids, attention_mask, label)\n",
    "        loss.backward()\n",
    "#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n",
    "        if (step + 1) % gradient_accumulation_steps == 0:\n",
    "            optimizer.step()\n",
    "            scheduler.step() \n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "        if epoch_loss is None:\n",
    "            epoch_loss = loss.item()\n",
    "        else:\n",
    "            epoch_loss = 0.98*epoch_loss + 0.02*loss.item()\n",
    "        epoch_step += 1\n",
    "    \n",
    "    used_time = (time.time() - start_time)/60\n",
    "    \n",
    "    print('Epoch = %d Epoch Mean Loss %.4f Time %.2f min' % (epoch, epoch_loss, used_time))\n",
    "    result = eval()\n",
    "    with open('result.txt', 'w') as f:\n",
    "        for r in result:\n",
    "            f.write('\\t'.join(r) + '\\n')\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    for r in result:\n",
    "        if not r: continue\n",
    "        y_true.append(label2id[r[1]])\n",
    "        y_pred.append(label2id[r[2]])\n",
    "    print(sklearn.metrics.f1_score(y_true, y_pred, average='micro'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
